{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lf7huAiYp-An"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "YHz2D-oIqBWa"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x44FFES-r6y0"
      },
      "source": [
        "# テキスト生成のフェデレーテッドラーニング"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"> Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/federated/tutorials/federated_learning_for_text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KbNz2tuvsAFB"
      },
      "source": [
        "**注意**: この Colab は `tensorflow_federated` pip パッケージの[最新リリースバージョン](https://github.com/tensorflow/federated#compatibility)での動作が確認されていますが、Tensorflow Federated プロジェクトは現在もプレリリース開発の段階にあるため、`master` では動作しない可能性があります。\n",
        "\n",
        "このチュートリアルは、[画像分類のフェデレーテッドラーニング](federated_learning_for_image_classification.ipynb)チュートリアルの概念に基づいて構成されており、フェデレーテッドラーニングの便利なアプローチをいくつか実演します。\n",
        "\n",
        "具体的には、以前にトレーニングした Keras モデルを読み込み、（シミュレーションされた）分散データセットでフェデレーテッドラーニングを使ってそのモデルをさらに洗練します。これはいくつかの理由により特に重要な作業です。シリアル化されたモデルを使用できることで、フェデレーテッドラーニングをほかの機械学習アプローチに簡単に混ぜることができるようになります。さらに、広範なトレーニング済みのモデルを使用することも可能です。たとえば、トレーニング済みの言語モデルは広く提供されてるため（[TF Hub](https://www.tensorflow.org/hub) など）、モデルをゼロからトレーニングする必要はほとんどありません。そのため、トレーニング済みのモデルを開始点に、フェデレーテッドラーニングを使って洗練させ、特定のアプリケーションに使用する分散データセットの特性に合わせて調整する方が合理的と言えます。\n",
        "\n",
        "このチュートリアルでは、ASCII 文字を生成する RNN より開始し、フェデレーテッドラーニングを通じて精緻化します。また、最終的な重みを元の Keras モデルにフィードし直し、評価とテキスト生成を標準のツールを使って簡単に行う方法も紹介します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LcC1AwjoqfR"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "ZjDQysatrc2S"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "b'Hello, World!'"
            ]
          },
          "execution_count": 3,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import collections\n",
        "import functools\n",
        "import os\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "np.random.seed(0)\n",
        "\n",
        "# Test the TFF is working:\n",
        "tff.federated_computation(lambda: 'Hello, World!')()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyICXwVAxvW9"
      },
      "source": [
        "## トレーニング済みモデルを読み込む\n",
        "\n",
        "TensorFlow チュートリアル「[Eager execution を使った RNN によるテキスト生成](https://www.tensorflow.org/tutorials/sequences/text_generation)」に従ってトレーニングされたモデルを使用しますが、[The Complete Works of Shakespeare](http://www.gutenberg.org/files/100/100-0.txt) を使用する代わりに、チャールズ・ディケンズの「[A Tale of Two Cities](http://www.ibiblio.org/pub/docs/books/gutenberg/9/98/98.txt)」と「[A Christmas Carol](http://www.ibiblio.org/pub/docs/books/gutenberg/4/46/46.txt)」のテキストでモデルを事前トレーニングしています。\n",
        "\n",
        "語彙を拡大する以外は元のチュートリアルを変更していないため、初期モデルは最新の状態ではありませんが、合理的な予測を生成するものであり、このチュートリアルの目的には十分と言えます。最終モデルは `tf.keras.models.save_model(include_optimizer=False)` を使って保存されています。\n",
        "\n",
        "このチュートリアルでは、フェデレーテッドラーニングを使用して、このシェイクスピアのモデルを精緻化します。TFF が提供するフェデレーテッドバージョンのデータを使用します。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XgF8e2Ksyq1F"
      },
      "source": [
        "### vocab ルックアップテーブルの生成"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "IlCgQBRVymwR"
      },
      "outputs": [],
      "source": [
        "# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:\n",
        "vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\\'/37;?bfjnrvzBFJNRVZ\"&*.26:\\naeimquyAEIMQUY]!%)-159\\r')\n",
        "\n",
        "# Creating a mapping from unique characters to indices\n",
        "char2idx = {u:i for i, u in enumerate(vocab)}\n",
        "idx2char = np.array(vocab)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2EH6MFRdzAwd"
      },
      "source": [
        "### トレーニング済みモデルの読み込みとテキストの生成"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "iIK674SrtCTm"
      },
      "outputs": [],
      "source": [
        "def load_model(batch_size):\n",
        "  urls = {\n",
        "      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',\n",
        "      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}\n",
        "  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())\n",
        "  url = urls[batch_size]\n",
        "  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  \n",
        "  return tf.keras.models.load_model(local_file, compile=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "WvuwZBX5Ogfd"
      },
      "outputs": [],
      "source": [
        "def generate_text(model, start_string):\n",
        "  # From https://www.tensorflow.org/tutorials/sequences/text_generation\n",
        "  num_generate = 200\n",
        "  input_eval = [char2idx[s] for s in start_string]\n",
        "  input_eval = tf.expand_dims(input_eval, 0)\n",
        "  text_generated = []\n",
        "  temperature = 1.0\n",
        "\n",
        "  model.reset_states()\n",
        "  for i in range(num_generate):\n",
        "    predictions = model(input_eval)\n",
        "    predictions = tf.squeeze(predictions, 0)\n",
        "    predictions = predictions / temperature\n",
        "    predicted_id = tf.random.categorical(\n",
        "        predictions, num_samples=1)[-1, 0].numpy()\n",
        "    input_eval = tf.expand_dims([predicted_id], 0)\n",
        "    text_generated.append(idx2char[predicted_id])\n",
        "\n",
        "  return (start_string + ''.join(text_generated))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "MGAdStJ5wDPV"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "What of TensorFlow Federated, you ask? Sall\n",
            "yesterday. Received the Bailey.\"\n",
            "\n",
            "\"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur\n",
            "Defarge. \"Let his mind, hon in his\n",
            "life and message; four declare\n"
          ]
        }
      ],
      "source": [
        "# Text generation requires a batch_size=1 model.\n",
        "keras_model_batch1 = load_model(batch_size=1)\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKMUn-TlgxuP"
      },
      "source": [
        "## Shakespere のフェデレーテッドデータを読み込んで事前処理する\n",
        "\n",
        "`tff.simulation.datasets` パッケージには、\"clients\" に分割されたさまざまなデータセットが含まれます。各 client はフェデレーテッドラーニングに含まれる可能性のある特定のデバイス上のデータセットに対応しています。\n",
        "\n",
        "これらのデータセットは、実際の分散データでのトレーニングの課題をシミュレーションで再現する現実的な非 IID データ分布を示します。このデータの事前処理は、[Leaf project](https://arxiv.org/abs/1812.01097)（[github](https://github.com/TalwalkarLab/leaf)）のツールを使用して行われています。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "di3nStTDg0qc"
      },
      "outputs": [],
      "source": [
        "train_data, test_data = tff.simulation.datasets.shakespeare.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_iiY65Vv4QNK"
      },
      "source": [
        "`shakespeare.load_data()` が提供するデータセットは、文字列 `Tensors` で構成されています。各行はシェイクスピア劇の登場人物のセリフです。client キーは、劇の名前と登場人物の名前を結合したもので、たとえば<br> `MUCH_ADO_ABOUT_NOTHING_OTHELLO` は「*Much Ado About Nothing*」という劇の登場人物オセロのセリフに対応しています。実勢のフェデレーテッドラーニングシナリオでは、client は ID で識別または追跡されることはありませんが、シミュレーションでは、キー付きのデータセットを使用する方が役に立ちます。\n",
        "\n",
        "ここでは、King Lear のデータを例とします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "FEKiy1ntmmnk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tf.Tensor(b'', shape=(), dtype=string)\n",
            "tf.Tensor(b'What?', shape=(), dtype=string)\n"
          ]
        }
      ],
      "source": [
        "# Here the play is \"The Tragedy of King Lear\" and the character is \"King\".\n",
        "raw_example_dataset = train_data.create_tf_dataset_for_client(\n",
        "    'THE_TRAGEDY_OF_KING_LEAR_KING')\n",
        "# To allow for future extensions, each entry x\n",
        "# is an OrderedDict with a single key 'snippets' which contains the text.\n",
        "for x in raw_example_dataset.take(2):\n",
        "  print(x['snippets'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kUnbI5Hp4sXg"
      },
      "source": [
        "`tf.data.Dataset` 変換を使用して、このデータを上記で読み込んだ文字列 RNN のトレーニング用に準備します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "9kDkmGe-7No7"
      },
      "outputs": [],
      "source": [
        "# Input pre-processing parameters\n",
        "SEQ_LENGTH = 100\n",
        "BATCH_SIZE = 8\n",
        "BUFFER_SIZE = 10000  # For dataset shuffling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "W95Of6Bwsrfc"
      },
      "outputs": [],
      "source": [
        "# Construct a lookup table to map string chars to indexes,\n",
        "# using the vocab loaded above:\n",
        "table = tf.lookup.StaticHashTable(\n",
        "    tf.lookup.KeyValueTensorInitializer(\n",
        "        keys=vocab, values=tf.constant(list(range(len(vocab))),\n",
        "                                       dtype=tf.int64)),\n",
        "    default_value=0)\n",
        "\n",
        "\n",
        "def to_ids(x):\n",
        "  s = tf.reshape(x['snippets'], shape=[1])\n",
        "  chars = tf.strings.bytes_split(s).values\n",
        "  ids = table.lookup(chars)\n",
        "  return ids\n",
        "\n",
        "\n",
        "def split_input_target(chunk):\n",
        "  input_text = tf.map_fn(lambda x: x[:-1], chunk)\n",
        "  target_text = tf.map_fn(lambda x: x[1:], chunk)\n",
        "  return (input_text, target_text)\n",
        "\n",
        "\n",
        "def preprocess(dataset):\n",
        "  return (\n",
        "      # Map ASCII chars to int64 indexes using the vocab\n",
        "      dataset.map(to_ids)\n",
        "      # Split into individual chars\n",
        "      .unbatch()\n",
        "      # Form example sequences of SEQ_LENGTH +1\n",
        "      .batch(SEQ_LENGTH + 1, drop_remainder=True)\n",
        "      # Shuffle and form minibatches\n",
        "      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "      # And finally split into (input, target) tuples,\n",
        "      # each of length SEQ_LENGTH.\n",
        "      .map(split_input_target))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jw98HnKmEhuh"
      },
      "source": [
        "元のシーケンスの形成と上記のバッチの形成では、`drop_remainder=True` を使って単純化しています。つまり、少なくともテキストの `(SEQ_LENGTH + 1) * BATCH_SIZE` 文字を持たない登場人物（client））のデータセットは空となります。この状況を解消するために使用される一般的なアプローチはバッチを特殊なトークンでパッドし、パディングトークンを考慮しないように損失量をマスクする方法です。\n",
        "\n",
        "これではサンプルが複雑化してしまうため、このチュートリアルでは[標準的なチュートリアル](https://www.tensorflow.org/tutorials/sequences/text_generation)と同様にフルバッチのみを使用します。ただし、多数のユーザーが小さなデータセットを持つことになるため、フェデレーテッドの設定ではこの問題はより明確に現れます。\n",
        "\n",
        "では、`raw_example_dataset` を事前処理し、型を確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "7rTal7bksWwc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))\n"
          ]
        }
      ],
      "source": [
        "example_dataset = preprocess(raw_example_dataset)\n",
        "print(example_dataset.element_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePT8Oawm8SRP"
      },
      "source": [
        "## モデルをコンパイルし、事前処理済みのデータでテストする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEgDsz-48cAq"
      },
      "source": [
        "コンパイルされていない Keras モデルを読み込みましたが、`keras_model.evaluate` を実行するには、損失とメトリックとともにコンパイルする必要があります。また、オプティマイザにコンパイルし、フェデレーテッドラーニングのオンデバイスオプティマイザとして使用します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RsuVZ5KMWnn8"
      },
      "source": [
        "元のチュートリアルには文字レベルの精度（最高確率が適切な次の文字に配置される予測の割合）が含まれていませんでした。これは便利なメトリックであるため、追加することにします。ただし、予測の階数は 3（`BATCH_SIZE * SEQ_LENGTH` の各予測に対するロジットのベクトル）であり、`SparseCategoricalAccuracy` は階数 2 の予測のみを期待するため、新しいメトリッククラスを定義する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "gOUiDBvmWlM9"
      },
      "outputs": [],
      "source": [
        "class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):\n",
        "\n",
        "  def __init__(self, name='accuracy', dtype=tf.float32):\n",
        "    super().__init__(name, dtype=dtype)\n",
        "\n",
        "  def update_state(self, y_true, y_pred, sample_weight=None):\n",
        "    y_true = tf.reshape(y_true, [-1, 1])\n",
        "    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])\n",
        "    return super().update_state(y_true, y_pred, sample_weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2X9eFgt94PM"
      },
      "source": [
        "これで、モデルをコンパイルし、`example_dataset` で評価できるようになりました。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "c3Xd-52-9zGa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel\n",
            "16195584/16193984 [==============================] - 0s 0us/step\n",
            "16203776/16193984 [==============================] - 0s 0us/step\n",
            "Evaluating on an example Shakespeare character: 0.402750\n",
            "Expected accuracy for random guessing: 0.012\n",
            "Evaluating on completely random data: 0.013\n"
          ]
        }
      ],
      "source": [
        "BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.\n",
        "keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "keras_model.compile(\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[FlattenedCategoricalAccuracy()])\n",
        "\n",
        "# Confirm that loss is much lower on Shakespeare than on random data\n",
        "loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)\n",
        "print(\n",
        "    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))\n",
        "\n",
        "# As a sanity check, we can construct some completely random data, where we expect\n",
        "# the accuracy to be essentially random:\n",
        "random_guessed_accuracy = 1.0 / len(vocab)\n",
        "print('Expected accuracy for random guessing: {a:.3f}'.format(\n",
        "    a=random_guessed_accuracy))\n",
        "random_indexes = np.random.randint(\n",
        "    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))\n",
        "data = collections.OrderedDict(\n",
        "    snippets=tf.constant(\n",
        "        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))\n",
        "random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))\n",
        "loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)\n",
        "print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lH0WzL5L8Lm4"
      },
      "source": [
        "## フェデレーテッドラーニングでモデルを微調整する"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NCao4M3L_tsA"
      },
      "source": [
        "TFF はすべての TensorFlow 計算をシリアル化するため、非 Python 環境で実行することが可能です（現時点では、Python で実装されたシミュレーションランタイムのみを利用できます）。Eager モードで実行してはいますが（TF 2.0）、現時点では、TFF は \"`with tf.Graph.as_default()`\" 文のコンテキスト内に必要な演算を作成して、TensorFlow 計算をシリアル化しています。したがって、モデルを関数が制御するグラフに導入するために TFF が使用できる関数を指定する必要があります。これを次のようにして行います。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "5KadIvFp7m6y"
      },
      "outputs": [],
      "source": [
        "# Clone the keras_model inside `create_tff_model()`, which TFF will\n",
        "# call to produce a new copy of the model inside the graph that it will \n",
        "# serialize. Note: we want to construct all the necessary objects we'll need \n",
        "# _inside_ this method.\n",
        "def create_tff_model():\n",
        "  # TFF uses an `input_spec` so it knows the types and shapes\n",
        "  # that your model expects.\n",
        "  input_spec = example_dataset.element_spec\n",
        "  keras_model_clone = tf.keras.models.clone_model(keras_model)\n",
        "  return tff.learning.from_keras_model(\n",
        "      keras_model_clone,\n",
        "      input_spec=input_spec,\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJF_yhJxAi2l"
      },
      "source": [
        "これで、フェデレーテッドアベレージングのイテレーション処理を構築する準備が整いました。これをモデルの改善に使用します（フェデレーテッドアベレージングアルゴリズムの詳細は、論文「[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629)」をご覧ください）。\n",
        "\n",
        "コンパイルされた Keras モデルを使用して、フェデレーテッドトレーニングの各ラウンドの後に標準的な（非フェデレーテッド）評価を実行します。これは、シミュレーションされたフェデレーテッドラーニングを行う場合で標準的なテストデータセットがある場合の研究目的に役立ちます。\n",
        "\n",
        "現実的な実稼働環境では、これと同じテクニックを使用してフェデレーテッドラーニングでモデルをトレーニングし、テストや QA を行えるように分散ベンチマークデータセットで評価します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "my3PW3qhAMDA"
      },
      "outputs": [],
      "source": [
        "# This command builds all the TensorFlow graphs and serializes them: \n",
        "fed_avg = tff.learning.build_federated_averaging_process(\n",
        "    model_fn=create_tff_model,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qVOkzs9C9kmv"
      },
      "source": [
        "次は最も単純なループで、1 つのバッチの単一の client における 1 つのラウンドで、フェデレーテッドアベレージングを実行します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "lrjUrkjq9jYk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss=4.403, accuracy=0.132\n"
          ]
        }
      ],
      "source": [
        "state = fed_avg.initialize()\n",
        "state, metrics = fed_avg.next(state, [example_dataset.take(5)])\n",
        "train_metrics = metrics['train']\n",
        "print('loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "    l=train_metrics['loss'], a=train_metrics['accuracy']))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2CjvVg0FZpS"
      },
      "source": [
        "では、もう少し興味深いトレーニングと評価ループを記述してみましょう。\n",
        "\n",
        "このシミュレーションは比較的素早く実行します。各ラウンドで同一の 3 つの client でトレーニングしますが、それぞれで 2 つのミニバッチのみを考慮しています。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "wE386-rbMCve"
      },
      "outputs": [],
      "source": [
        "def data(client, source=train_data):\n",
        "  return preprocess(source.create_tf_dataset_for_client(client)).take(5)\n",
        "\n",
        "\n",
        "clients = [\n",
        "    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',\n",
        "]\n",
        "\n",
        "train_datasets = [data(client) for client in clients]\n",
        "\n",
        "# We concatenate the test datasets for evaluation with Keras by creating a \n",
        "# Dataset of Datasets, and then identity flat mapping across all the examples.\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    [data(client, test_data) for client in clients]).flat_map(lambda x: x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cU3FuY00MOoX"
      },
      "source": [
        "`fed_avg.initialize()` で生成されるモデルの最初の状態は、読み込まれた重みではなく、Keras モデルのランダムなイニシャライザに基づきます。`clone_model()` は重みを複製しないためです。トレーニング済みのモデルからトレーニングを始めるには、読み込んだモデルから直接、サーバー状態のモデルの重みを設定します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "vm_-PU8OFXpY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Round 0\n",
            "\tEval: loss=3.372, accuracy=0.395\n",
            "\tTrain: loss=4.317, accuracy=0.083\n",
            "Round 1\n",
            "\tEval: loss=4.300, accuracy=0.129\n",
            "\tTrain: loss=4.172, accuracy=0.184\n",
            "Round 2\n",
            "\tEval: loss=4.152, accuracy=0.201\n",
            "\tTrain: loss=4.077, accuracy=0.191\n",
            "Round 3\n",
            "\tEval: loss=4.031, accuracy=0.189\n",
            "\tTrain: loss=3.965, accuracy=0.192\n",
            "Round 4\n",
            "\tEval: loss=3.946, accuracy=0.183\n",
            "\tTrain: loss=3.877, accuracy=0.196\n",
            "\tEval: loss=3.885, accuracy=0.168\n"
          ]
        }
      ],
      "source": [
        "NUM_ROUNDS = 5\n",
        "\n",
        "# The state of the FL server, containing the model and optimization state.\n",
        "state = fed_avg.initialize()\n",
        "\n",
        "# Load our pre-trained Keras model weights into the global model state.\n",
        "state = tff.learning.state_with_new_model_weights(\n",
        "    state,\n",
        "    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],\n",
        "    non_trainable_weights=[\n",
        "        v.numpy() for v in keras_model.non_trainable_weights\n",
        "    ])\n",
        "\n",
        "\n",
        "def keras_evaluate(state, round_num):\n",
        "  # Take our global model weights and push them back into a Keras model to\n",
        "  # use its standard `.evaluate()` method.\n",
        "  keras_model = load_model(batch_size=BATCH_SIZE)\n",
        "  keras_model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[FlattenedCategoricalAccuracy()])\n",
        "  state.model.assign_weights_to(keras_model)\n",
        "  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)\n",
        "  print('\\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))\n",
        "\n",
        "\n",
        "for round_num in range(NUM_ROUNDS):\n",
        "  print('Round {r}'.format(r=round_num))\n",
        "  keras_evaluate(state, round_num)\n",
        "  state, metrics = fed_avg.next(state, train_datasets)\n",
        "  train_metrics = metrics['train']\n",
        "  print('\\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(\n",
        "      l=train_metrics['loss'], a=train_metrics['accuracy']))\n",
        "\n",
        "print('Final evaluation')\n",
        "keras_evaluate(state, NUM_ROUNDS + 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SoshvcHhXVa6"
      },
      "source": [
        "デフォルトの変更により、大きな違いを得るほどのトレーニングはまだ行われていませんが、より長時間、より多くの Shakespeare データをトレーニングする場合、更新したモデルに生成されるテキストのスタイルに違いがみられるようになります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "NTUig7QmXavy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What of TensorFlow Federated, you ask? Shalways, I will call your\n",
            "compet with any city brought their faces uncompany,\" besumed him. \"When he\n",
            "sticked Madame Defarge pushed the lamps.\n",
            "\n",
            "\"Have I often but no unison. She had probably come,\n"
          ]
        }
      ],
      "source": [
        "# Set our newly trained weights back in the originally created model.\n",
        "keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])\n",
        "# Text generation requires batch_size=1\n",
        "print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4DA1Fkf5mN0s"
      },
      "source": [
        "## 推奨される拡張\n",
        "\n",
        "このチュートリアルは導入ステップにしかすぎません！次に、このノートブックを拡張するためのアイデアをいくつか示しています。\n",
        "\n",
        "- トレーニングする client をランダムにサンプリングするより現実的なトレーニングループを記述する。\n",
        "- client データセットに \"`.repeat(NUM_EPOCHS)`\" を使用して、ローカルトレーニングの複数のエポックを試してみる（[McMahan et. al.](https://arxiv.org/abs/1602.05629) で示す例）。これを行っている[画像分類のフェデレーテッドラーニング](federated_learning_for_image_classification.ipynb)もご覧ください。\n",
        "- `compile()` コマンドを変更して、client でさまざまな最適化アルゴリズムを使った実験を行う。\n",
        "- `build_federated_averaging_process` に `server_optimizer` 属性を使用し、サーバー上にモデルの更新を適用するためのさまざまなアルゴリズムを試してみる。\n",
        "- `build_federated_averaging_process` に `client_weight_fn` 属性を使用して、client のさまざまな重みづけを試してみる。デフォルトは、client のサンプル数で client の更新を重みづけしますが、`client_weight_fn=lambda _: tf.constant(1.0)` などのように行うことができます。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "federated_learning_for_text_generation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
